import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader, TensorDataset, random_split, RandomSampler
from torch.func import functional_call, vmap, grad
import numpy as np
import copy
import random

# Older tuning knobs, left commented out; nothing in the visible code reads them.
#bootstrap_bs = 5
#BatchNum = 100
#NumSample = 20
#B = 50
# Seed constants. NOTE(review): presumably intended for the Python, NumPy and
# torch RNGs respectively, but none of them is applied to an RNG in the visible
# code -- confirm whether other modules import and use them.
seed = 2
npseed = 6
torchseed = 5

def LossScaledTrace(test_model, train_data, d, train_size, B=3200):
    """Compute loss-scaled gradient statistics from per-sample gradients.

    For every sample ``i`` in ``train_data.dataset`` the cross-entropy loss
    ``l_i`` and the flattened parameter gradient ``g_i`` (all parameters,
    concatenated in ``named_parameters()`` order) are computed with
    ``torch.func``.  From these the function builds::

        g_bar = (1 / B) * sum_i g_i
        H     = (1 / B) * sum_i g_i g_i^T / max(2 * l_i, 1e-14)
        C     = (1 / B) * sum_i g_i g_i^T  -  g_bar g_bar^T
        L     = (1 / B) * sum_i l_i

    Args:
        test_model: ``nn.Module`` to analyse.  It is deep-copied, so the
            caller's model is neither moved to another device nor modified.
        train_data: object exposing ``.dataset.data`` (inputs, one sample per
            leading index) and ``.dataset.targets`` (e.g. a ``DataLoader`` over
            a torchvision-style dataset).
        d: total number of scalar parameters of the model.
        train_size: unused; kept so existing call sites keep working.
        B: normaliser applied to all sums above.  NOTE(review): the sums run
            over *every* sample in the dataset, so ``B`` presumably should be
            the dataset size -- confirm at the call sites.

    Returns:
        A 3-tuple of NumPy scalars
        ``(trace(H @ C) / max(2 * L, 1e-14), ||H||_F, trace(H))``.

    Raises:
        ValueError: if ``d`` does not match the model's parameter count.
    """
    # Per-sample gradients are materialised chunk by chunk so that peak memory
    # is O(chunk_size * d) instead of O(num_samples * d).
    chunk_size = 256

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = copy.deepcopy(test_model).to(device)
    criterion = torch.nn.CrossEntropyLoss()

    params = {k: v.detach() for k, v in model.named_parameters()}
    buffers = {k: v.detach() for k, v in model.named_buffers()}

    num_params = sum(p.numel() for p in params.values())
    if num_params != d:
        raise ValueError(
            f"d={d} does not match the model's parameter count ({num_params})"
        )

    def compute_loss(params, buffers, sample, target):
        # The model expects a batch dimension, so wrap the single sample.
        # NOTE(review): samples are fed exactly as stored in dataset.data; if
        # the model needs an extra channel axis the dataset must provide it.
        batch = sample.unsqueeze(0)
        targets = target.unsqueeze(0)

        predictions = functional_call(model, (params, buffers), (batch,))
        loss = criterion(predictions, targets)
        # Return the loss a second time as auxiliary output so that the
        # per-sample loss is available next to the per-sample gradient.
        return loss, loss

    per_sample_grad_and_loss = vmap(
        grad(compute_loss, has_aux=True), in_dims=(None, None, 0, 0)
    )

    inputs = torch.as_tensor(train_data.dataset.data)
    targets = torch.as_tensor(train_data.dataset.targets)
    # CrossEntropyLoss takes class indices as int64; only class-probability
    # targets (one row per sample) are floating point.
    target_dtype = torch.long if targets.dim() == 1 else torch.float
    num_samples = inputs.shape[0]

    grad_sum = torch.zeros(d, device=device)
    outer_sum = torch.zeros(d, d, device=device)
    scaled_outer_sum = torch.zeros(d, d, device=device)
    loss_sum = torch.zeros((), device=device)

    for start in range(0, num_samples, chunk_size):
        stop = start + chunk_size
        x = inputs[start:stop].to(device=device, dtype=torch.float)
        y = targets[start:stop].to(device=device, dtype=target_dtype)

        grads, losses = per_sample_grad_and_loss(params, buffers, x, y)
        # (chunk, d) matrix: one flattened gradient per row, parameters in
        # named_parameters() order.
        flat = torch.cat(
            [grads[name].reshape(x.shape[0], -1) for name in params], dim=1
        )
        # Same floor as the scalar path: never divide by less than 1e-14.
        denom = torch.clamp(2.0 * losses, min=10e-15)

        grad_sum += flat.sum(dim=0)
        # flat^T @ flat is the sum of the per-sample outer products g g^T.
        outer_sum += flat.t() @ flat
        scaled_outer_sum += (flat / denom.unsqueeze(1)).t() @ flat
        loss_sum += losses.sum()

    FullGradient = (grad_sum / B).reshape(d, 1)
    Hessian = scaled_outer_sum / B
    CovarianceMatrix = outer_sum / B - torch.mm(FullGradient, FullGradient.t())
    FullLoss = loss_sum.item() / B

    return torch.trace(torch.mm(Hessian, CovarianceMatrix)).cpu().numpy() / max([(FullLoss * 2), 10e-15]), \
           torch.norm(Hessian, p='fro').cpu().numpy() / 1, torch.trace(Hessian).cpu().numpy() / 1
